###############################################################################
# Utilities: Set up server
###############################################################################

import functools
import socket

# Hostname prefixes that identify the worker machines of the cluster.
RECOGNIZED_WORKERS = ["swarm", "node"]

def on_worker():
    """Return True when this process runs on a recognized worker host.

    A host is recognized when its hostname starts with any prefix listed in
    ``RECOGNIZED_WORKERS``.
    """
    hostname = socket.gethostname()
    # str.startswith accepts a tuple of prefixes and matches if any of them does.
    return hostname.startswith(tuple(RECOGNIZED_WORKERS))


if on_worker():
    from os import environ
    # Disable XLA's default preallocation of most of the GPU memory. This is
    # set BEFORE importing numpyro (which builds on JAX) so the flag is already
    # in the environment when the JAX GPU backend is initialised and reads it.
    environ['XLA_PYTHON_CLIENT_PREALLOCATE']='false'
    # Alternative, currently disabled: allocate/free GPU memory strictly on
    # demand (slower, but minimal footprint).
    # environ['XLA_PYTHON_CLIENT_ALLOCATOR']='platform'
    import numpyro
    # Run numpyro/JAX computations on the GPU on recognized worker hosts.
    numpyro.set_platform("gpu")





###############################################################################
# generic pre-processing utilities
###############################################################################


import pickle
import jax.numpy as np
import jax
import os


def np_arrays(arrays):
    """Convert every element of ``arrays`` with ``np.array`` and return a list.

    ``np`` is ``jax.numpy`` in this module, so the elements become JAX arrays.
    """
    return list(map(np.array, arrays))


def array_equal_list(a, b):
    """Return True if two sequences of arrays are element-wise identical.

    Args:
        a, b: Iterables of array-likes (e.g. lists of arrays).

    Returns:
        bool: True only if both sequences have the same length and every pair
        ``(a[k], b[k])`` satisfies ``np.array_equal``.
    """
    # Materialise so lengths can be compared; a bare zip() silently truncates
    # to the shorter input, which made [x, y] compare equal to [x].
    a, b = list(a), list(b)
    if len(a) != len(b):
        return False
    return all(bool(np.array_equal(i, j)) for i, j in zip(a, b))


def load_npz_arrays(fname):
    """Load every array stored in an ``.npz`` archive, in archive order.

    Args:
        fname: Path (or file-like) of an ``.npz`` archive.

    Returns:
        list: The stored arrays, ordered as listed in the archive.
    """
    # np.load on an .npz returns a lazily-reading NpzFile that keeps the file
    # open; the context manager closes it once all arrays have been read.
    with np.load(fname) as archive:
        return [archive[key] for key in archive.files]


###############################################################################
# generic display utilities
###############################################################################
def print_dict(dic):
    """Print each ``key: value`` pair of ``dic`` between two separator lines.

    Arrays (``np.ndarray``, i.e. JAX arrays here) with more than one dimension
    start on a new line so their rows line up; everything else is printed
    inline.
    """
    banner = "##################################"
    print(banner)
    for key, value in dic.items():
        multiline = isinstance(value, np.ndarray) and value.ndim > 1
        if multiline:
            print(f"{key}: \n{value}")
        else:
            print(f"{key}: {value}")
    print(banner)

def print_params_shape(rez):
    """Recursively print the ``.shape`` of every leaf in a nested mapping.

    Any ``Mapping`` value (plain ``dict``, flax ``FrozenDict``, ...) is printed
    as ``key:`` and then recursed into; every other value must expose a
    ``.shape`` attribute and is printed as ``key: shape``. Nested keys are
    printed without extra indentation.

    Args:
        rez: A (possibly nested) mapping of parameter arrays.
    """
    # collections.abc.Mapping covers dict and flax's FrozenDict (a Mapping
    # subclass), so flax is no longer required just to walk a plain dict.
    from collections.abc import Mapping
    for k, v in rez.items():
        if isinstance(v, Mapping):
            print(f"{k}:")
            print_params_shape(v)
        else:
            print(f"{k}: {v.shape}")


###############################################################################
# generic utilities
###############################################################################

def batch2D(func):
  """
  Make a function that expects a single 2D input also accept a batch of them.

  Args:
      func (function):
        A vmap compatible function whose first positional argument is a 2D
        array. Extra keyword arguments are forwarded unchanged.
  Returns:
    function:
      A wrapper (keeping ``func``'s name and docstring) that calls ``func``
      directly on a 2D input and ``jax.vmap``s it over the leading axis of a
      3D input.
  Raises:
    AttributeError: If the input is neither 2D nor 3D.
  """
  @functools.wraps(func)
  def _func(x, **kwargs):
    if x.ndim == 2:
      return func(x, **kwargs)
    elif x.ndim == 3:
      # Close over kwargs so they are shared by every batch element instead of
      # being handed to vmap (which would try to map over them too).
      _func_temp = lambda x: func(x, **kwargs)
      return jax.vmap(_func_temp)(x)
    else:
      raise AttributeError(
          f"Expects a 2D array or a batch of 2D arrays, got ndim={x.ndim}")
  return _func


def batch2D_class(func):
  """
  Method counterpart of ``batch2D``: batch a method ``func(self, x, **kwargs)``
  over a leading axis of ``x`` when ``x`` is 3D.

  Args:
      func (function): A method whose first argument after ``self`` is 2D.
  Returns:
    function: The batched method, keeping ``func``'s name and docstring.
  """
  @functools.wraps(func)
  def _func(self, x, **kwargs):
    # Bind self so batch2D sees a plain function of (x, **kwargs).
    return batch2D(functools.partial(func, self))(x, **kwargs)
  return _func


def batch1D(func):
  """
  Make a function that expects a single 1D input also accept a batch of them.

  Args:
      func (function):
        A vmap compatible function whose first positional argument is a 1D
        array. Extra keyword arguments are forwarded unchanged.
  Returns:
    function:
      A wrapper (keeping ``func``'s name and docstring) that calls ``func``
      directly on a 1D input and ``jax.vmap``s it over the leading axis of a
      2D input.
  Raises:
    AttributeError: If the input is neither 1D nor 2D.
  """
  @functools.wraps(func)
  def _func(x, **kwargs):
    if x.ndim == 1:
      return func(x, **kwargs)
    elif x.ndim == 2:
      # Close over kwargs so they are shared by every batch element instead of
      # being handed to vmap (which would try to map over them too).
      _func_temp = lambda x: func(x, **kwargs)
      return jax.vmap(_func_temp)(x)
    else:
      raise AttributeError(
          f"Expects a 1D array or a batch of 1D arrays, got ndim={x.ndim}")
  return _func


def batch1D_class(func):
  """
  Method counterpart of ``batch1D``: batch a method ``func(self, x, **kwargs)``
  over a leading axis of ``x`` when ``x`` is 2D.

  Args:
      func (function): A method whose first argument after ``self`` is 1D.
  Returns:
    function: The batched method, keeping ``func``'s name and docstring.
  """
  @functools.wraps(func)
  def _func(self, x, **kwargs):
    # Bind self so batch1D sees a plain function of (x, **kwargs).
    return batch1D(functools.partial(func, self))(x, **kwargs)
  return _func


def eye_3d(D,n):
    """Return an ``(n, D, D)`` stack of ``n`` copies of the ``D x D`` identity."""
    identity = np.eye(D)
    # Add a leading axis of length 1, then repeat it n times along that axis.
    return np.repeat(identity[np.newaxis], n, axis=0)

@batch2D
def get_sing_vals(L):
  """Return the singular values of a 2D ``L`` (or of each matrix in a 3D batch).

  Per ``numpy.linalg.svd`` semantics, the values come back sorted in
  descending order.
  """
  # Only the singular values are used, so skip computing U and V^H.
  return np.linalg.svd(L, compute_uv=False)


@batch2D
def get_condition_number(L):
  """Return ``max(s**2) / min(s**2)`` for the singular values ``s`` of ``L``.

  This is the square of L's 2-norm condition number. It is batched over a
  leading axis by ``batch2D``.
  NOTE(review): the name suggests cond(L) itself; squaring matches the
  condition number of ``L @ L.T`` for square L. Confirm which is intended.
  No guard: a zero smallest singular value yields inf/nan.
  """
  squared = get_sing_vals(L) ** 2
  largest = np.max(squared)
  smallest = np.min(squared)
  return largest / smallest


def smooth(x, alpha=0.01, conservative = False):
    """Exponentially smooth the sequence ``x`` along its leading axis.

    Each output is ``y_t = y_{t-1} * (1 - a_t) + a_t * x_t``, starting from
    ``y = 0.0``. The result has one entry per step of ``x``.

    Args:
        x: Sequence scanned over its leading axis. Presumably 1-D (one scalar
            per step), since the running value starts as the scalar 0.0 and
            ``jax.lax.scan`` needs its carry shape to stay fixed. TODO confirm.
        alpha: Smoothing rate used at every step.
        conservative: If True, use ``a_t = max(alpha, 1 / (t + 1))`` (t is
            0-based). This gives a running mean for the first steps, so for
            ``alpha <= 1`` the first output equals ``x[0]`` instead of being
            biased toward the 0.0 start. It then switches to the fixed rate.
            This flag is a Python bool resolved at trace time.

    Returns:
        The smoothed values stacked by ``jax.lax.scan``.
    """
    def _step(state, x_t):
        prev, count = state
        rate = np.maximum(alpha, 1 / (count + 1)) if conservative else alpha
        current = prev * (1 - rate) + rate * x_t
        return (current, count + 1), current

    _, history = jax.lax.scan(_step, (0.0, 0), x)
    return history


def is_list_tuple_array(x):
  """Return True if ``x`` is a list, a tuple or an ``np.ndarray``.

  NOTE(review): ``np`` is ``jax.numpy`` here, so ``np.ndarray`` is the JAX
  array type. A plain NumPy array may not be recognized; confirm that is
  intended.
  """
  return isinstance(x, (list, tuple, np.ndarray))


def dump_objects(objects, fname, verbose, mode='wb', **kwargs):
  """Pickle ``objects`` into the file ``fname``.

  Args:
    objects: Any picklable object.
    fname: Destination path.
    verbose: If truthy, print the destination path after saving.
    mode: File mode passed to ``open`` (default ``'wb'``).
    **kwargs: Extra keyword arguments forwarded to ``open``.
  """
  # The context manager guarantees the file is flushed and closed, even if
  # pickling raises; the previous bare open() leaked the handle.
  with open(fname, mode, **kwargs) as fh:
    pickle.dump(obj=objects, file=fh)
  if verbose:
    print(f"Files saved at {fname}")


def load_objects(fname, verbose, mode='rb', **kwargs):
  """Unpickle and return the object stored in ``fname``.

  Args:
    fname: Source path.
    verbose: If truthy, print the source path before loading.
    mode: File mode passed to ``open`` (default ``'rb'``).
    **kwargs: Extra keyword arguments forwarded to ``open``.

  Returns:
    The unpickled object.
  """
  if verbose:
    print(f"Files loaded from: {fname}")
  # SECURITY: unpickling can execute arbitrary code. Only load trusted files.
  # The context manager closes the handle; the previous bare open() leaked it.
  with open(fname, mode, **kwargs) as fh:
    return pickle.load(file=fh)


def dict_get(_dict, list_of_keys):
  """Return ``[_dict.get(k) for k in list_of_keys]``; missing keys map to None."""
  lookup = _dict.get
  return [lookup(key) for key in list_of_keys]


def create_dir(dir_name):
    """Create ``dir_name`` and any missing parents; do nothing if it already exists.

    Raises:
        FileExistsError: If ``dir_name`` exists but is not a directory.
    """
    # exist_ok=True removes the check-then-create race: with a separate
    # os.path.exists() test, a concurrent process could create the directory in
    # between and make os.makedirs() fail.
    os.makedirs(dir_name, exist_ok=True)

def get_attribute(objects, name):
    """Return attribute ``name`` from the first object in ``objects`` that has it.

    Args:
        objects: A single object, or a list/tuple of objects searched in order.
        name: Attribute name to look up.

    Raises:
        AttributeError: If none of the objects has the attribute.
    """
    candidates = objects if isinstance(objects, (list, tuple)) else [objects]
    for candidate in candidates:
        if not hasattr(candidate, name):
            continue
        return getattr(candidate, name)
    raise AttributeError(
            f"The {name} does not match the attributes in any of the "
            f"following objects {list(candidates)}.")

def mv(m, v):
    """Batched matrix-vector product: ``out[..., i] = sum_j m[..., i, j] * v[..., j]``.

    Leading (``...``) dimensions are broadcast by ``einsum``.
    """
    subscripts = "...ij, ...j -> ...i"
    return np.einsum(subscripts, m, v)

def mm(m, n):
    """Batched matrix product: ``out[..., i, k] = sum_j m[..., i, j] * n[..., j, k]``.

    Leading (``...``) dimensions are broadcast by ``einsum``.
    """
    subscripts = "...ij, ...jk -> ...ik"
    return np.einsum(subscripts, m, n)

def vtv(m):
    """Squared Euclidean norm along the last axis: ``out[...] = sum_i m[..., i] ** 2``."""
    subscripts = "...i, ...i -> ..."
    return np.einsum(subscripts, m, m)

